#!/usr/bin/env python3
import json, re, numpy as np, pandas as pd, math

# Input/output paths; relative, so resolved against the current working directory.
# Plateau table read by main(): rows are filtered on "claimable" and use the
# stack_id, A_theta, R_G_bin, rmse_flat and Mstar_bin columns.
PLATEAU = "outputs/lensing_plateau.csv"
# Per-stack fitting windows, parsed by load_windows().
WINDOWS = "outputs/windows.json"
# Per-stack rows (stack_id, b, weight) used to total each stack's in-window weight.
STACKS  = "data/prestacked_stacks.csv"
# Result JSON written by main(); replaces any earlier file at this path.
OUT     = "outputs/size_regression.json"  # overwrite with QC version

def rg_mid(lbl: object) -> float:
    """Return the midpoint of an ``"lo-hi"`` R_G bin label, or NaN.

    En/em dashes are accepted as the separator and surrounding whitespace is
    ignored. The match is anchored only at the start, so trailing text such as
    a unit is ignored. Missing values (``None`` or NaN, which pandas yields for
    empty CSV cells) and labels that do not parse return NaN instead of raising.

    >>> rg_mid("1.0 – 3.0")
    2.0
    """
    # NaN is truthy, so the old ``lbl or ""`` guard let it reach ``.strip()``.
    if lbl is None or (isinstance(lbl, float) and math.isnan(lbl)):
        return float("nan")
    s = str(lbl).strip().replace("—", "-").replace("–", "-")
    # A proper decimal number: rejects "." or "1.2.3", which a looser
    # [0-9.]+ pattern would match and float() would then choke on.
    num = r"([0-9]+(?:\.[0-9]*)?|\.[0-9]+)"
    m = re.match(r"\s*" + num + r"\s*-\s*" + num, s)
    if m is None:
        return float("nan")
    return 0.5 * (float(m.group(1)) + float(m.group(2)))

def _parse_window(v):
    """Return ``("idx", i0, i1)`` / ``("b", b_min, b_max)`` for one record, else None.

    Index windows take precedence when a record carries both kinds. Non-dict
    records and records with neither ``i0`` nor ``b_min`` yield None.
    """
    if not isinstance(v, dict):
        return None
    if v.get("i0") is not None:
        return ("idx", int(v["i0"]), int(v["i1"]))
    if v.get("b_min") is not None:
        return ("b", float(v["b_min"]), float(v["b_max"]))
    return None


def load_windows(p):
    """Load per-stack fitting windows from the JSON file at *p*.

    Two layouts are accepted:
      * an object mapping stack id -> record, or
      * a list of records carrying their id under ``"stack_id"`` (falling back
        to ``"id"`` when that is missing or falsy); records without an id, or
        that are not objects, are skipped.

    Returns a dict mapping stack id to ``("idx", i0, i1)`` for records with
    ``i0`` or ``("b", b_min, b_max)`` for records with ``b_min``. Records with
    neither are ignored, and any other top-level JSON value yields ``{}``.

    Raises KeyError if a record has ``i0`` without ``i1`` (or ``b_min``
    without ``b_max``), and TypeError/ValueError for non-numeric bounds.
    """
    with open(p, encoding="utf-8") as f:
        W = json.load(f)

    if isinstance(W, dict):
        pairs = list(W.items())
    elif isinstance(W, list):
        pairs = [(v.get("stack_id") or v.get("id"), v)
                 for v in W if isinstance(v, dict)]
        pairs = [(sid, v) for sid, v in pairs if sid]
    else:
        pairs = []

    out = {}
    for sid, v in pairs:
        spec = _parse_window(v)
        if spec is not None:
            out[sid] = spec
    return out

def wls_slope(x, y, w):
    """Return the weighted least-squares slope of *y* on *x*, or NaN.

    Computes ``m = sum(w*(x-xb)*(y-yb)) / sum(w*(x-xb)**2)`` with weighted
    means ``xb`` and ``yb``; this is the slope of the fit ``y = a + m*x`` that
    minimises ``sum(w*(y - a - m*x)**2)``.

    Returns NaN when the total weight is not a positive finite number, or when
    the positively weighted x values have no spread (the slope is undefined).
    """
    x, y, w = (np.asarray(a, float) for a in (x, y, w))
    W = w.sum()
    if not np.isfinite(W) or W <= 0:
        return float("nan")
    # Exact degeneracy test. With identical x, ``x - xb`` is only rounding
    # noise, so ``den`` can be a tiny positive number and the slope an arbitrary
    # huge value. Bootstrap resamples that repeat one row hit this often.
    xs = x[w > 0]
    if xs.size == 0 or xs.min() == xs.max():
        return float("nan")
    xb = (w * x).sum() / W
    yb = (w * y).sum() / W
    den = (w * (x - xb) ** 2).sum()
    if not den > 0:  # also catches NaN
        return float("nan")
    return float((w * (x - xb) * (y - yb)).sum() / den)

def _lookup_window(win, sid):
    """Return *win*'s entry for *sid*, trying ``str(sid)`` as a fallback (None if absent).

    JSON object keys are always strings, but pandas parses an all-numeric
    ``stack_id`` column as integers. Without the fallback, such stacks would
    never find their window and would silently get a NaN weight.
    """
    spec = win.get(sid)
    return spec if spec is not None else win.get(str(sid))


def _window_weight(g, spec):
    """Sum ``weight`` over the rows of one stack (sorted by ``b``) that lie inside *spec*."""
    mode, a, b = spec
    if mode == "idx":
        # Positional slice [i0, i1) clipped to the row count. i1 is also kept
        # >= i0 so a negative bound cannot trigger Python's "from the end" slicing.
        i0 = max(0, int(a))
        i1 = max(i0, min(len(g), int(b)))
        sel = g["weight"].iloc[i0:i1]
    else:
        # Closed interval b_min <= b <= b_max.
        sel = g.loc[(g["b"] >= float(a)) & (g["b"] <= float(b)), "weight"]
    return float(np.nansum(sel.to_numpy()))


def _window_weights(S, win):
    """Return a ``stack_id`` / ``window_weight`` frame (NaN for stacks without a window)."""
    rows = []
    for sid, g in S.groupby("stack_id"):
        spec = _lookup_window(win, sid)
        rows.append((sid, np.nan if spec is None else _window_weight(g, spec)))
    return pd.DataFrame(rows, columns=["stack_id", "window_weight"])


def _slope_summary(g, n_total, rng, n_boot, eps):
    """Weighted A_theta-vs-RG_mid slope for one Mstar bin plus its bootstrap 16/84% interval.

    *n_total* is the bin's row count before QC and is used for ``kept_fraction``.
    Bins with fewer than 3 rows get NaN slope and CI.
    """
    x = g["RG_mid"].to_numpy(float)
    y = pd.to_numeric(g["A_theta"], errors="coerce").to_numpy(float)
    w = 1.0 / (g["rmse_flat"].to_numpy(float) ** 2 + eps)  # flatter => heavier
    n = len(g)
    slope = lo = hi = float("nan")
    if n >= 3:
        slope = wls_slope(x, y, w)
        boots = np.empty(n_boot, float)
        for i in range(n_boot):
            # Pairs bootstrap with uniform row draws. The weights enter once,
            # through wls_slope. Also drawing rows with probability
            # proportional to w would count them twice and centre the interval
            # on a w**2-weighted slope instead of the reported estimate.
            idx = rng.choice(n, size=n, replace=True)
            boots[i] = wls_slope(x[idx], y[idx], w[idx])
        boots = boots[np.isfinite(boots)]
        if boots.size:
            lo, hi = (float(v) for v in np.percentile(boots, [16, 84]))
    return {"n_stacks": int(n), "slope_Atheta_vs_RG": float(slope),
            "CI_16": float(lo), "CI_84": float(hi), "weighted_qc": True,
            "kept": int(n), "kept_fraction": float(n / max(1, n_total))}


def main():
    """Fit A_theta vs R_G slopes per stellar-mass bin with a window-weight QC cut.

    1. Keep plateau rows whose ``claimable`` flag reads "true".
    2. Attach each stack's summed in-window weight (STACKS + WINDOWS).
    3. QC: keep rows with window weight >= the median and a numeric A_theta.
    4. Per ``Mstar_bin``: compute the WLS slope of A_theta on the R_G bin
       midpoint (weights 1/rmse_flat**2) and its bootstrap 16-84% interval.
    5. Write the results to OUT and echo them to stdout.
    """
    P = pd.read_csv(PLATEAU)
    S = pd.read_csv(STACKS).sort_values(["stack_id", "b"])
    win = load_windows(WINDOWS)
    P = P.loc[P["claimable"].astype(str).str.lower() == "true"].copy()

    # Per-stack window total weight from stacks + windows.
    P = P.merge(_window_weights(S, win), on="stack_id", how="left")

    # QC: keep stacks with window_weight >= median(window_weight).
    med = float(np.nanmedian(P["window_weight"]))
    numeric_A = pd.to_numeric(P["A_theta"], errors="coerce").notna()
    P_qc = P[(P["window_weight"] >= med) & numeric_A].copy()

    # Regression table: x = R_G bin midpoint; drop rows with non-finite inputs.
    P_qc["RG_mid"] = P_qc["R_G_bin"].apply(rg_mid)
    P_qc["rmse_flat"] = pd.to_numeric(P_qc["rmse_flat"], errors="coerce")
    P_qc = (P_qc.replace([np.inf, -np.inf], np.nan)
                .dropna(subset=["RG_mid", "A_theta", "rmse_flat"]))

    rng = np.random.default_rng(42)
    B = 4000
    eps = 1e-8

    out = {}
    for ms, g in P_qc.groupby("Mstar_bin"):
        n_total = int((P["Mstar_bin"] == ms).sum())
        out[ms] = _slope_summary(g, n_total, rng, B, eps)

    with open(OUT, "w") as f:
        json.dump(out, f, indent=2)
    print(f"Wrote {OUT} (weighted + QC). Median window weight = {med:.3f}")
    print(json.dumps(out, indent=2))


if __name__ == "__main__":
    main()
